import os
import warnings

import numpy as np
import pandas as pd
from scipy.cluster.vq import kmeans, vq, whiten
from sklearn.utils import resample
from sklift.datasets import fetch_criteo


def compute_uplift(treat, y, indices):
    """Return the uplift of ``y`` over a selected subset of rows.

    All inputs are converted with ``np.asarray`` first, so pandas Series are
    indexed by position (never by index label) and ``treat``/``y`` are
    compared element-by-element rather than aligned on their index.

    Args:
        treat: Per-row treatment indicator. Rows equal to 1 (or ``True``)
            are treated; rows equal to 0 (or ``False``) are control.
        y: Per-row outcome values, positionally aligned with ``treat``.
        indices: Anything that can index a NumPy array along axis 0,
            typically a boolean mask of the rows to use or an array of
            integer positions.

    Returns:
        Tuple ``(uplift, treated, baseline)``. ``treated`` and ``baseline``
        are the mean outcome of the selected treated and control rows, and
        ``uplift = treated - baseline``. If a group has no selected rows,
        ``np.mean`` yields NaN for it (with a RuntimeWarning), and the
        uplift is NaN as well.
    """
    idx = np.asarray(indices)
    # Positional selection: avoids pandas label-based lookup/alignment.
    y_selected = np.asarray(y)[idx]
    treat_selected = np.asarray(treat)[idx]

    treated = np.mean(y_selected[treat_selected == 1])
    baseline = np.mean(y_selected[treat_selected == 0])
    uplift = treated - baseline
    return uplift, treated, baseline


if __name__ == '__main__':

    dataset = fetch_criteo(target_col='visit',
                           treatment_col='treatment',
                           data_home='datasets')
    # to_numpy() avoids the extra full-dataset copy that np.array(...) forced.
    X = dataset.data.to_numpy()
    treat = dataset.treatment.to_numpy().astype(bool)
    visit = dataset.target.to_numpy().astype(int)

    # Stratify jointly on (treatment, target) so the sample keeps both ratios.
    stratify_cols = pd.concat([dataset.treatment, dataset.target], axis=1)
    n_samples = 100000
    n_clusters = 20

    X_sampled, treat_sampled, y_sampled = resample(
        X,
        treat,
        visit,
        n_samples=n_samples,
        stratify=stratify_cols,
        replace=False,
        random_state=20,
    )

    whitened = whiten(X_sampled)
    # scipy's kmeans removes centroids that lose all observations, so the
    # codebook may have fewer than n_clusters rows; use its real size below.
    codebook, _ = kmeans(whitened, n_clusters, seed=42)
    n_centroids = codebook.shape[0]

    # Nearest-centroid (Euclidean) label for every sampled row.
    belongs, _ = vq(whitened, codebook)
    counts = np.bincount(belongs, minlength=n_centroids)

    # Only occupied clusters are kept, so the means and sizes saved below are
    # aligned entry-by-entry (np.repeat needs equal lengths).
    occupied = np.flatnonzero(counts)
    cluster_sizes = counts[occupied]

    treateds = []
    baselines = []
    for k in occupied:
        _, treated, baseline = compute_uplift(
            treat_sampled, y_sampled, belongs == k)
        treateds.append(treated)
        baselines.append(baseline)

    undefined = np.isnan(treateds) | np.isnan(baselines)
    if undefined.any():
        warnings.warn(
            f'{int(undefined.sum())} cluster(s) lack treated or control '
            'samples; their saved means are NaN.',
            RuntimeWarning,
        )

    # Expand per-cluster means to one entry per sampled row, grouped by cluster.
    treateds_repeated = np.repeat(treateds, cluster_sizes)
    baselines_repeated = np.repeat(baselines, cluster_sizes)

    save_dir = '../save/bernoulli_criteo/data/'
    os.makedirs(save_dir, exist_ok=True)

    np.save(os.path.join(save_dir, 'treated_means.npy'), treateds_repeated)
    np.save(os.path.join(save_dir, 'baseline_means.npy'), baselines_repeated)
    np.save(os.path.join(save_dir, 'cluster_sizes.npy'), cluster_sizes)
